import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, StepLR
import torch.optim as optim
from torchvision.models import resnet18
import matplotlib.pyplot as plt
import math

if __name__ == '__main__':

    # Dataset sizes (training-split sample counts for several datasets;
    # only `adver_train_data` is actually used below).
    brain_train_data = 1350
    adver_train_data = int(51137 * 0.88)  # 88% train split
    yecaichong_data = 994  # train 994 val 109
    lr = 1e-4
    mode = 'cosineAnn'
    max_epoch = 25
    batch_size = 64
    ACCUMULATE = 2  # gradient-accumulation factor: one optimizer update per ACCUMULATE batches
    iters = math.ceil(adver_train_data / batch_size)  # batches per epoch
    # The scheduler is stepped once per optimizer update (not per batch),
    # so the cosine period is measured in updates: iters // ACCUMULATE per epoch.
    T = iters // ACCUMULATE * max_epoch  # cycle
    print(iters)

    # NOTE(review): `pretrained` is deprecated/removed in newer torchvision
    # (use `weights=None` there) — kept as-is for compatibility with the
    # version this script was written against.
    model = resnet18(pretrained=False)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    if mode == 'cosineAnn':
        # Single cosine decay over the whole run: LR anneals from `lr` to ~0
        # across T optimizer updates with no restarts.
        scheduler = CosineAnnealingLR(optimizer, T_max=T)
    elif mode == 'cosineAnnWarm':
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=2, T_mult=2)
        '''
        Taking T_0=5, T_mult=1 as an example:
        T_0: the step index at which the LR first returns to its initial value.
        T_mult: controls how fast the restart period grows.
            - If T_mult=1, the LR returns to its maximum (the initial LR) at
              T_0, 2*T_0, 3*T_0, ..., i*T_0, ...
                - i.e. at 5, 10, 15, 20, 25, ...
            - If T_mult>1, the LR returns to its maximum at
              T_0, (1+T_mult)*T_0, (1+T_mult+T_mult**2)*T_0, ...,
              (1+T_mult+T_mult**2+...+T_mult**i)*T_0, ...
                - e.g. with T_mult=2: 5, 15, 35, 75, 155, ...
        example:
            T_0=5, T_mult=1
        '''
    else:
        # Fail fast with a clear message instead of hitting a NameError
        # on `scheduler` inside the training loop below.
        raise ValueError('unknown scheduler mode: {!r}'.format(mode))

    plt.figure()
    cur_lr_list = []  # LR sampled after every batch, for plotting
    for epoch in range(max_epoch):
        model.train()
        print('epoch_{}'.format(epoch))
        for batch in range(iters):
            if (batch + 1) % ACCUMULATE == 0:  # Gradient Accumulate
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()  # step per optimizer update, matching T_max=T above
            # scheduler.step(epoch + batch / iters)
            cur_lr = optimizer.param_groups[-1]['lr']
            cur_lr_list.append(cur_lr)
            # print('cur_lr:', cur_lr)
        print('epoch_{}_end'.format(epoch))
        # scheduler.step()
    x_list = list(range(len(cur_lr_list)))
    plt.plot(x_list, cur_lr_list)
    plt.show()
